import torch.nn as nn

"""
    Defines ArtNet, a small convolutional neural network (CNN) for image classification.
"""
class ArtNet(nn.Module):
    """Three-stage convolutional image classifier.

    Each stage is Conv2d (3x3, "same" padding) -> BatchNorm2d -> ReLU,
    followed by 2x2 max pooling; dropout (p=0.2) is applied after the first
    two stages. The final feature map is flattened per sample and passed
    through two fully connected layers, producing raw (unnormalised) scores
    of shape ``(batch, num_classes)`` -- no softmax is applied.

    Args:
        num_classes: Number of output classes (last dimension of the output).
        image_size: Spatial size of the expected input images, either a
            single int (square images) or a ``(height, width)`` pair.
            Defaults to 128, which reproduces the original hard-coded
            ``16 * 16 * 32`` input size of ``fc1``.

    Raises:
        ValueError: If either spatial dimension is smaller than 8, since the
            three 2x2 poolings would leave an empty feature map.
    """

    def __init__(self, num_classes, image_size=128):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        # Three MaxPool2d(kernel_size=2) layers each floor-halve H and W,
        # which is equivalent to a single floor division by 8.
        pooled_h, pooled_w = height // 8, width // 8
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size must be at least 8 in each dimension, got {image_size!r}"
            )

        # Conv stages: the Sequential layout (conv, bn, relu) is kept so the
        # state_dict keys (e.g. "layer1.0.weight") match the original model.
        self.layer1 = self._conv_block(3, 12)
        self.layer2 = self._conv_block(12, 20)
        self.layer3 = self._conv_block(20, 32)

        # Pooling layer, shared by all three stages (it has no parameters).
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Dropout, applied after the first two stages only.
        self.dropout = nn.Dropout(p=0.2)

        # Fully connected head; 32 is the channel count of layer3.
        self.fc1 = nn.Linear(in_features=32 * pooled_h * pooled_w, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=num_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return a Conv2d(3x3, same) -> BatchNorm2d -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding="same"),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
        )

    def forward(self, input):
        """Compute class scores for a batch of images.

        Args:
            input: Image batch; assumed to be (batch, 3, H, W) with H and W
                matching the ``image_size`` given at construction.

        Returns:
            Tensor of shape (batch, num_classes) with raw scores.
        """
        output = self.layer1(input)
        output = self.pool(output)
        output = self.dropout(output)
        output = self.layer2(output)
        output = self.pool(output)
        output = self.dropout(output)
        output = self.layer3(output)
        output = self.pool(output)

        # Flatten everything except the batch dimension. Unlike
        # view(-1, <features>), this never folds several samples into one row
        # when the input size is wrong -- fc1 raises a shape error instead --
        # and it also works on non-contiguous tensors.
        output = output.flatten(1)
        output = self.fc1(output)
        # NOTE(review): there is no non-linearity between fc1 and fc2, so the
        # two layers compose to a single linear map. Adding one (e.g.
        # nn.ReLU()) would change outputs of already-trained weights, so it is
        # left as-is; consider it when retraining.
        output = self.fc2(output)

        return output
